"""
This code is written for Windows, and it uses pymetis package. In case it is being run in Linux, please change the
import statement from import pymetis as metis to import metis.
"""
import torch
import torch.nn.functional as F

import numpy as np
import pymetis as metis
import os


# %%
class PartitionTask(torch.nn.Module):
    """Self-supervised task: predict each node's graph-partition id.

    For every dataset in ``data.datasets`` the graph is split into a fixed
    number of parts with METIS, and the resulting part ids (relabelled so that
    id 0 is the largest part) are used as classification targets. Partitions
    are cached as ``./saved/<dataset_name>_partition_<num_parts>.npy``.

    Args:
        data: object exposing ``datasets``, a sequence whose items have a
            ``name`` and a ``data`` attribute providing ``edge_index`` and
            ``num_nodes``.
        embedding_size: input width of the per-dataset linear prediction heads.
        device: device the pseudo-labels and prediction heads are moved to.
    """

    # Number of partitions requested per dataset name; any name not listed
    # here falls back to ``_DEFAULT_NUM_PARTS``.
    _NUM_PARTS_BY_DATASET = {
        'citeseer': 400,  # was (1000) in AutoSSL source code
        'photo': 100,
        'computers': 100,
        'wiki': 20,
    }
    _DEFAULT_NUM_PARTS = 400

    def __init__(self, data, embedding_size, device):
        super().__init__()
        self.name = 'graph partitioning'
        self.data = data
        self.dataset_names = [dataset.name for dataset in self.data.datasets]
        self.device = device

        # Look the count up by dataset *name*. (The previous code compared the
        # whole ``data`` object to 'wiki', so that branch could never match.)
        self.num_parts = [
            self._NUM_PARTS_BY_DATASET.get(dataset_name, self._DEFAULT_NUM_PARTS)
            for dataset_name in self.dataset_names
        ]

        self.pseudo_labels = self.create_pseudo_labels()
        # ModuleList (not a plain list) so the heads are registered submodules:
        # they then show up in parameters()/state_dict() and follow .to().
        self.predictor = torch.nn.ModuleList(
            [torch.nn.Linear(embedding_size, num_parts) for num_parts in self.num_parts]
        ).to(self.device)

    def get_loss(self, embeddings, dataset_name):
        """Return the partition-classification loss for one dataset.

        Args:
            embeddings: 2-D tensor fed to the dataset's linear head; rows are
                presumably aligned with the dataset's nodes (they are scored
                against the per-node pseudo-labels) -- TODO confirm.
            dataset_name: name of one of the datasets given at construction.

        Returns:
            Scalar negative log-likelihood loss.

        Raises:
            ValueError: if ``dataset_name`` is not a known dataset.
        """
        index = self.dataset_names.index(dataset_name)
        logits = self.predictor[index](embeddings)
        log_probs = F.log_softmax(logits, dim=1)
        return F.nll_loss(log_probs, self.pseudo_labels[index])

    def create_pseudo_labels(self):
        """Load, or compute and cache, the partition labels of every dataset.

        Returns:
            List with one ``LongTensor`` of partition ids per dataset, in the
            order of ``self.dataset_names``, placed on ``self.device``.
        """
        pseudo_labels = []
        for dataset_name, num_parts, dataset in zip(
                self.dataset_names, self.num_parts, self.data.datasets):
            partition_file = f'./saved/{dataset_name}_partition_{num_parts}.npy'
            if os.path.exists(partition_file):
                print(f'loading {partition_file}')
                partition_labels = np.load(partition_file)
            else:
                print('Performing graph partitioning with PyMetis...')
                membership = self._partition_with_metis(dataset.data, num_parts)
                partition_labels = self._relabel_by_size(membership, num_parts)
                # The cache directory may not exist on a fresh checkout.
                os.makedirs(os.path.dirname(partition_file), exist_ok=True)
                np.save(partition_file, partition_labels)
            # Use the same (size-sorted) labels whether they were just computed
            # or read back from disk, so the first run matches later runs.
            pseudo_labels.append(torch.LongTensor(partition_labels).to(self.device))
        return pseudo_labels

    @staticmethod
    def _partition_with_metis(graph, num_parts):
        """Run METIS on ``graph`` and return the raw per-node part ids as an array."""
        # Convert once to Python ints; indexing the tensor element by element
        # inside the loop is orders of magnitude slower.
        edge_index = graph.edge_index
        sources = edge_index[0].tolist()
        targets = edge_index[1].tolist()

        adj_list = [[] for _ in range(graph.num_nodes)]
        # Every column is inserted in both directions.
        # NOTE(review): if edge_index already stores both directions of each
        # edge, every neighbour ends up listed twice -- confirm this is intended.
        for node_a, node_b in zip(sources, targets):
            adj_list[node_a].append(node_b)
            adj_list[node_b].append(node_a)

        _, membership = metis.part_graph(adjacency=adj_list, nparts=num_parts)
        return np.array(membership)

    @staticmethod
    def _relabel_by_size(membership, num_parts):
        """Renumber part ids so that id 0 is the largest part, 1 the next, and so on."""
        partition_sizes = np.bincount(membership, minlength=num_parts)
        sort_ind = np.argsort(partition_sizes)[::-1]
        # new_id[old_id] is the rank of the old part in descending-size order.
        new_id = np.empty(len(sort_ind), dtype=membership.dtype)
        new_id[sort_ind] = np.arange(len(sort_ind))
        return new_id[membership]
